"""ResNet in PyTorch.
ImageNet-Style ResNet
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
Adapted from: https://github.com/bearpaw/pytorch-classification
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
        is_last: When True, ``forward`` also returns the pre-activation tensor.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        out_planes = self.expansion * planes
        self.is_last = is_last

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as the identity; a 1x1 conv + BN projection
        # is only needed when the residual branch changes the tensor shape.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``.

        When ``is_last`` is set, a tuple ``(output, pre_activation)`` is
        returned instead, where ``pre_activation`` is the sum before the ReLU.
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        preact = self.bn2(self.conv2(hidden))
        preact += self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv (each with BN).

    The block outputs ``expansion * planes`` channels.

    Args:
        in_planes: Number of input channels.
        planes: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution (and of the projection shortcut,
            when one is needed).
        is_last: When True, ``forward`` also returns the pre-activation tensor.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        out_planes = self.expansion * planes
        self.is_last = is_last

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential acts as the identity; a 1x1 conv + BN projection
        # is only needed when the residual branch changes the tensor shape.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``.

        When ``is_last`` is set, a tuple ``(output, pre_activation)`` is
        returned instead, where ``pre_activation`` is the sum before the ReLU.
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        preact = self.bn3(self.conv3(hidden))
        preact += self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated


class ResNet(nn.Module):
    """ResNet backbone with a 3x3, stride-1 stem and no max-pooling.

    ``forward`` returns the globally average-pooled, flattened output of the
    last stage; no classification layer is applied.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` attribute.
        num_blocks: Indexable with four entries, the number of blocks in each
            of the four stages.
        in_channel: Number of channels of the input.
        zero_init_residual: When True, the weight of the last BatchNorm of
            every ``Bottleneck`` / ``BasicBlock`` is set to zero.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride of the first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[idx], stride=first_stride)
            setattr(self, 'layer{}'.format(idx + 1), stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Registered but not applied in forward(); kept so the module's
        # parameters / state_dict keys stay as they are.
        self.fc_reduce = nn.Linear(512, 128)

        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch so that every
        # residual block starts out as an identity mapping. Reported to improve
        # accuracy by 0.2~0.3% in https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    last_bn = module.bn3
                elif isinstance(module, BasicBlock):
                    last_bn = module.bn2
                else:
                    continue
                nn.init.constant_(last_bn.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        ``self.in_planes`` is advanced to the output width after every block.
        """
        stage_blocks = []
        for idx in range(num_blocks):
            block_stride = stride if idx == 0 else 1
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x, layer=100):
        """Return the pooled, flattened features of ``x``.

        ``layer`` is accepted but not used by this implementation.
        """
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        # self.fc_reduce is intentionally not applied here.
        return torch.flatten(pooled, 1)

class ResNet50LargeInput(nn.Module):
    """ResNet backbone with a 7x7, stride-2 stem followed by a 3x3 max-pool.

    The actual depth is determined by ``block`` and ``num_blocks``; the class
    itself does not hard-code a 50-layer layout. ``forward`` returns the
    globally average-pooled, flattened output of the last stage.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` attribute.
        num_blocks: Indexable with four entries, the number of blocks in each
            of the four stages.
        in_channel: Number of channels of the input.
        zero_init_residual: When True, the weight of the last BatchNorm of
            every ``Bottleneck`` / ``BasicBlock`` is set to zero.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet50LargeInput, self).__init__()
        self.in_planes = 64

        # Adjust initial convolution for 224x224 images
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet layers
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN of each residual branch so every block
        # starts as an identity mapping. Both block types are handled so the
        # flag is honoured regardless of which block class is passed in.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        ``self.in_planes`` is advanced to the output width after every block.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the pooled, flattened features of ``x``."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

class ResNet18LargeInput(nn.Module):
    """ResNet backbone with a 7x7, stride-2 stem followed by a 3x3 max-pool.

    The actual depth is determined by ``block`` and ``num_blocks``.
    ``forward`` returns the globally average-pooled, flattened output of the
    last stage.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` attribute.
        num_blocks: Indexable with four entries, the number of blocks in each
            of the four stages.
        in_channel: Number of channels of the input.
        zero_init_residual: When True, the weight of the last BatchNorm of
            every ``Bottleneck`` / ``BasicBlock`` is set to zero.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super().__init__()
        self.in_planes = 64

        # Stem sized for 224x224 inputs.
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (planes, stride of the first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[idx], stride=first_stride)
            setattr(self, 'layer{}'.format(idx + 1), stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        norm_types = (nn.BatchNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, norm_types):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Optionally zero the last BN of each residual branch so that every
        # block starts out as an identity mapping.
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    last_bn = module.bn3
                elif isinstance(module, BasicBlock):
                    last_bn = module.bn2
                else:
                    continue
                nn.init.constant_(last_bn.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        ``self.in_planes`` is advanced to the output width after every block.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage_blocks = []
        for block_stride in per_block_strides:
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Return the pooled, flattened features of ``x``."""
        feats = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        return torch.flatten(self.avgpool(feats), 1)


def resnet18(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` with stage depths [2, 2, 2, 2].

    Keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)


def resnet34(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` with stage depths [3, 4, 6, 3].

    Keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)


def resnet50(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` with stage depths [3, 4, 6, 3].

    Keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


def resnet101(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` with stage depths [3, 4, 23, 3].

    Keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

def resnet18_large_input(**kwargs):
    """Build a ``ResNet18LargeInput`` of ``BasicBlock`` with depths [2, 2, 2, 2].

    Keyword arguments are forwarded to the ``ResNet18LargeInput`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet18LargeInput(BasicBlock, stage_depths, **kwargs)

def resnet50_large_input(**kwargs):
    """Build a ``ResNet50LargeInput`` of ``Bottleneck`` with depths [3, 4, 6, 3].

    Keyword arguments are forwarded to the ``ResNet50LargeInput`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet50LargeInput(Bottleneck, stage_depths, **kwargs)

# Registry mapping an encoder name to ``[constructor, feature_dim]``.
# The wrapper models in this file call ``constructor()`` with no arguments and
# use ``feature_dim`` as the input width of whatever they stack on the encoder.
model_dict = {
    'resnet18': [resnet18, 512], # Kimia used 128.
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
    'resnet18_large': [resnet18_large_input, 512],
    'resnet50_large': [resnet50_large_input, 2048],
    # Wrapped in a lambda because OneLayerReLUEncoder is defined further down
    # in this file; the name is only looked up when the entry is called.
    'two-layer-nn': [lambda: OneLayerReLUEncoder(), 128],
}


class LinearBatchNorm(nn.Module):
    """BatchNorm over feature vectors, implemented with ``BatchNorm2d``.

    Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose.

    Args:
        dim: Number of features per sample.
        affine: Passed through to ``nn.BatchNorm2d``.
    """

    def __init__(self, dim, affine=True):
        super().__init__()
        self.dim = dim
        self.bn = nn.BatchNorm2d(dim, affine=affine)

    def forward(self, x):
        """Normalise ``x`` viewed as ``(-1, dim)`` and return it in that shape."""
        # Give every sample a dummy 1x1 spatial extent so BatchNorm2d applies.
        as_feature_map = x.view(-1, self.dim, 1, 1)
        return self.bn(as_feature_map).view(-1, self.dim)

class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument unchanged."""

    def forward(self, x):
        """Return ``x`` itself (no copy is made)."""
        return x

class SupConResNet(nn.Module):
    """Encoder backbone followed by a projection head; output is L2-normalised.

    Args:
        name: Key into ``model_dict`` selecting the encoder.
        head: One of ``'linear'``, ``'mlp'``, ``'fixed'`` or ``'identity'``.
        feat_dim: Output width of the ``'linear'``, ``'mlp'`` and ``'fixed'``
            heads. (Kimia used 512 for feat_dim.)
        k: Passed as ``init_scale`` to the ``'fixed'`` head.

    Raises:
        NotImplementedError: If ``head`` is not one of the supported names.
    """

    def __init__(self, name='resnet50', head='mlp', feat_dim=128, k=1):
        super().__init__()
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()

        if head == 'identity':
            projection = Identity()
        elif head == 'fixed':
            projection = FixedWeightingHead(dim_in, feat_dim, bias=True, init_scale=k)
        elif head == 'mlp':
            projection = nn.Sequential(
                nn.Linear(dim_in, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, feat_dim),
            )
        elif head == 'linear':
            projection = nn.Linear(dim_in, feat_dim)
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))
        self.head = projection

    def forward(self, x):
        """Encode ``x``, apply the head and normalise along dimension 1."""
        projected = self.head(self.encoder(x))
        return F.normalize(projected, dim=1)

class OneLayerReLUEncoder(nn.Module):
    """Single ``Linear`` layer followed by a ReLU, applied to flattened inputs.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Output width; also exposed as ``self.out_dim``.
    """

    def __init__(self, input_dim=20, hidden_dim=128):
        super(OneLayerReLUEncoder, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.out_dim = hidden_dim

    def forward(self, x):
        """Flatten every sample of ``x`` to one vector and encode it.

        Returns:
            Tensor of shape ``(x.size(0), hidden_dim)``.
        """
        # reshape() gives the same result as view() for contiguous inputs but
        # also accepts non-contiguous ones (e.g. after a transpose/permute),
        # for which view() raises a RuntimeError.
        x = x.reshape(x.size(0), -1)
        return self.net(x)

class SupCEResNet(nn.Module):
    """Encoder followed by a single linear classification layer.

    Args:
        name: Key into ``model_dict`` selecting the encoder.
        num_classes: Number of outputs of the classifier.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        """Return the classifier output for the encoded ``x``."""
        features = self.encoder(x)
        return self.fc(features)


class LinearClassifier(nn.Module):
    """Linear classifier applied to precomputed encoder features.

    Args:
        name: Key into ``model_dict``; only the registered feature width of
            that entry is used, no encoder is built.
        num_classes: Number of outputs of the classifier.
    """

    def __init__(self, name='resnet50', num_classes=10):
        super().__init__()
        feat_dim = model_dict[name][1]
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, features):
        """Return ``fc(features)``."""
        logits = self.fc(features)
        return logits

# class FixedWeightingHead(nn.Module):
#     """Fixed weighting projection head"""
#     def __init__(self, dim_in, feat_dim=128, k=1):
#         super(FixedWeightingHead, self).__init__()
#         self.k = k
#         self.dim_in = dim_in
#         self.weights = self._compute_weights(dim_in, k)
#         # self.linear = nn.Linear(dim_in, feat_dim, bias=False)
#         # with torch.no_grad():
#         #     self.linear.weight.copy_(self.weights.unsqueeze(0).repeat(feat_dim, 1))

#     def _compute_weights(self, dim_in, k):
#         """
#         Compute the fixed weights for the projection head.
#         """
#         # trial 16
#         weights = torch.tensor([1 / (k ** i) for i in range(dim_in)], dtype=torch.float32) 
        
#         # trial 17
#         # weights = torch.tensor([1 / (k ** (dim_in - 1 - i)) for i in range(dim_in)], dtype=torch.float32) 
        
#         # trial 18
#         # center = (dim_in - 1) / 2  # Center index
#         # sigma = dim_in / 6  # Standard deviation (covers ~99.7% for 3σ)
#         # indices = torch.arange(dim_in, dtype=torch.float32)  # Indices: 0, 1, ..., dim_in-1
#         # weights = torch.exp(-((indices - center) ** 2) / (2 * sigma ** 2))  # Gaussian formula
#         # weights = weights / weights.max()  # Normalize weights to [0, 1]

#         # trial 19: k = 1
#         # weights = torch.tensor([1 / (1 ** i) for i in range(dim_in)], dtype=torch.float32)
#         return weights

#     def forward(self, r):
#         """
#         Apply fixed weighting to the input features.
#         """
#         device = r.device  # Ensure weights are on the same device as input
#         weights = self.weights.to(device)  # Move weights to the same device
#         weighted_features = r * weights
#         return weighted_features #r * weights

class FixedWeightingHead(nn.Module):
    """Frozen linear map from ``dim_in`` to ``feat_dim`` features.

    The weight is drawn once with ``kaiming_uniform_`` (``a=sqrt(5)``),
    multiplied by ``init_scale`` and stored as a parameter with
    ``requires_grad=False``.

    Args:
        dim_in: Number of input features.
        feat_dim: Number of output features.
        bias: When True a zero bias (also frozen) is stored; otherwise
            ``self.bias`` is None.
        init_scale: Factor applied to the freshly initialised weight.
    """

    def __init__(self, dim_in, feat_dim, bias=True, init_scale=1.0):
        super().__init__()

        # Weight has shape [feat_dim, dim_in]; initialise, scale, then freeze.
        frozen_weight = nn.init.kaiming_uniform_(torch.empty(feat_dim, dim_in), a=math.sqrt(5))
        frozen_weight.mul_(init_scale)
        self.weight = nn.Parameter(frozen_weight, requires_grad=False)

        self.bias = nn.Parameter(torch.zeros(feat_dim), requires_grad=False) if bias else None

    def forward(self, x):
        """Return ``x @ weight.T + bias`` for ``x`` of shape (batch, dim_in)."""
        return F.linear(x, self.weight, self.bias)

class SimSiam(nn.Module):
    """Encoder with a projector MLP and a predictor MLP, applied to two views.

    Args:
        name: Key into ``model_dict`` selecting the encoder.
        projector_dim: Hidden and output width of the projector; also the
            output width of the predictor.
        predictor_dim: Hidden width of the predictor.
    """

    def __init__(self, name='resnet50', projector_dim=2048, predictor_dim=512):
        super().__init__()
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()

        # Projector: Linear -> BN -> ReLU -> Linear
        projector_layers = [
            nn.Linear(dim_in, projector_dim),
            nn.BatchNorm1d(projector_dim),
            nn.ReLU(),
            nn.Linear(projector_dim, projector_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

        # Predictor: Linear -> BN -> ReLU -> Linear (back to projector_dim)
        predictor_layers = [
            nn.Linear(projector_dim, predictor_dim),
            nn.BatchNorm1d(predictor_dim),
            nn.ReLU(),
            nn.Linear(predictor_dim, projector_dim),
        ]
        self.predictor = nn.Sequential(*predictor_layers)

    def forward(self, x1, x2):
        """Run both views through encoder, projector and predictor.

        Returns:
            Tuple ``(h1, h2, z1, z2, p1, p2)`` of encoder outputs, projector
            outputs and predictor outputs for the two views.
        """
        h1, h2 = (self.encoder(view) for view in (x1, x2))
        z1, z2 = (self.projector(feat) for feat in (h1, h2))
        p1, p2 = (self.predictor(proj) for proj in (z1, z2))
        return h1, h2, z1, z2, p1, p2

class BarlowTwinsModel(nn.Module):
    """Encoder followed by a three-layer, bias-free projector MLP.

    Args:
        name: Key into ``model_dict`` selecting the encoder.
        projector_dim: Output width of the projector.
        hidden_dim: Width of the two hidden projector layers.
    """

    def __init__(self, name='resnet18', projector_dim=512, hidden_dim=1024):
        super().__init__()

        # Backbone
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()

        # Every Linear is followed by BatchNorm1d; the last one has no ReLU
        # and its BatchNorm has no affine parameters.
        projector_layers = [
            nn.Linear(dim_in, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projector_dim, bias=False),
            nn.BatchNorm1d(projector_dim, affine=False),
        ]
        self.projector = nn.Sequential(*projector_layers)

    def forward(self, x):
        """Return the projector output for the encoded ``x``."""
        # x: [batch_size * 2, channels, height, width]
        return self.projector(self.encoder(x))


class DirectDLRModel(nn.Module):
    """DirectDLR model without a projection head.

    Args:
        name: Key into ``model_dict`` selecting the encoder backbone
            (e.g. ``'resnet18'``).
        sub_dim: Number of leading feature dimensions returned when
            ``forward`` is called with ``return_subvector_only=True``.
    """

    def __init__(self, name='resnet18', sub_dim=128):
        super(DirectDLRModel, self).__init__()

        # Only the constructor is needed; the registered feature width is unused.
        model_fun, _ = model_dict[name]
        self.encoder = model_fun()
        self.sub_dim = sub_dim

    def forward(self, x, return_subvector_only=False):
        """Encode ``x`` and optionally keep only the leading ``sub_dim`` features.

        Args:
            x: Input passed straight to the encoder.
            return_subvector_only: If True, return ``h[:, :sub_dim]`` instead
                of the full encoder output ``h``.

        Returns:
            The encoder output, or its first ``sub_dim`` columns. If
            ``sub_dim`` exceeds the feature width, the slice simply returns
            all columns.
        """
        h = self.encoder(x)  # assumed to be a flat [B, D] feature matrix

        if return_subvector_only:
            return h[:, :self.sub_dim]
        return h

